package examples.Sudoku where

import Data.TreeMap (Tree, keys)
import Data.List as DL hiding (find, union)


-- Core board model. All of these are plain aliases.
type Element    = Int           -- a digit that can be placed in a cell: 1,2,3,4,5,6,7,8,9
type Zelle      = [Element]     -- cell: the list of candidates still possible for it
type Position   = Int           -- cell index 0..80 (see the index diagram below)
type Feld       = (Position, Zelle)     -- field: a position together with its cell
type Brett      = [Feld]        -- board: a list of fields

--- Data type for assumptions and for the conclusions drawn from them.
--- @IS p e@ states that position p holds element e, @ISNOT p e@ that it does not.
--- NOTE(review): the leading @!@ presumably marks the constructor fields as strict - confirm.
data Assumption =
              !ISNOT Position Element
            | !IS    Position Element


derive Eq Assumption
derive Ord Assumption
-- Rendering uses the printable coordinate of the position, e.g. "a1=5" for IS and "a1/5" for ISNOT
instance Show Assumption where
    show (IS p e)    = pname p ++ "=" ++ e.show
    show (ISNOT p e) = pname p ++ "/" ++ e.show

-- render a list of assumptions on one line, separated by single blanks
showcs cs = joined " " [ Assumption.show c | c <- cs ]

elements :: [Element]           -- every digit that may occupy a cell
elements = [1, 2, 3, 4, 5, 6, 7, 8, 9]

{-
    a  b  c   d  e  f   g  h  i
     0  1  2 | 3  4  5 | 6  7  8    1
     9 10 11 |12 13 14 |15 16 17    2
    18 19 20 |21 22 23 |24 25 26    3
    ---------|---------|--------
    27 28 29 |30 31 32 |33 34 35    4
    36 37 38 |39 40 41 |42 43 44    5
    45 46 47 |48 49 50 |51 52 53    6
    ---------|---------|--------
    54 55 56 |57 58 59 |60 61 62    7
    63 64 65 |66 67 68 |69 70 71    8
    72 73 74 |75 76 77 |78 79 80    9
-}

positions :: [Position]         -- every cell index of the board
positions = [0 .. 80]
rowstarts :: [Position]         -- index of the first cell of each row
rowstarts =  [ 9 * r | r <- [0 .. 8] ]
colstarts :: [Position]         -- index of the first cell of each column
colstarts =  [0 .. 8]
boxstarts :: [Position]         -- index of the upper left cell of each box
boxstarts =  [ r + c | r <- [0, 27, 54], c <- [0, 3, 6] ]
boxmuster :: [Position]         -- box pattern: offsets that, added to a box's upper left position, give the box
boxmuster =  [ r + c | r <- [0, 9, 18], c <- [0, 1, 2] ]


--- look up the field for a position; yields @(p, [])@ if the board has no such field
getf :: Brett -> Position  -> Feld
getf [] p = (p, [])
getf (f:rest) p = if fst f == p then f else getf rest p


--- candidates of the cell at a position
getc :: Brett -> Position -> Zelle
getc b p = case getf b p of
    (_, zelle) -> zelle

--- all positions in the row that contains the given position
row :: Position -> [Position]
row p = [ first + k | k <- [0 .. 8] ] where first = 9 * (p `quot` 9)

--- all positions in the column that contains the given position
col :: Position -> [Position]
col p = [ offset + start | start <- rowstarts ] where offset = p `mod` 9

--- all positions in the box that contains the given position
box :: Position -> [Position]
box p  = [ corner + off | off <- boxmuster ] where
    band   = 27 * (p `div` 27)              -- first position of the band of three rows: 0, 27 or 54
    stack  = 3 * ((p `mod` 9) `div` 3)      -- first column of the stack of three columns: 0, 3 or 6
    corner = band + stack                   -- upper left position of the box

--- true if and only if the candidate list has exactly one member, i.e. the field is solved
single :: Zelle -> Bool
single cs = case cs of
    [_] -> true
    _   -> false

--- true unless the candidate list has exactly one member (so it is also true for an empty list)
unsolved :: Zelle -> Bool
unsolved cs = case cs of
    [_] -> false
    _   -> true

-- all rows, columns and boxes, each as a list of positions
allrows = [ row p | p <- rowstarts ]
allcols = [ col p | p <- colstarts ]
allboxs = [ box p | p <- boxstarts ]
-- every container tagged with the name of its kind: rows first, then columns, then boxes
allrcb  = [ ("row", ps) | ps <- allrows ]
          ++ [ ("col", ps) | ps <- allcols ]
          ++ [ ("box", ps) | ps <- allboxs ]


-- the three functions that map a position to its container, each with the container kind's name
containers :: [(Position -> [Position], String)]
containers = zip [row, col, box] ["row", "col", "box"]

-- ----------------- PRINTING ------------------------------------
-- printable coordinate of a field: column letter followed by row digit,
-- so the upper left field is a1 and the lower right one is i9
pname p = packed [colLetter, rowDigit]
    where
        colLetter = chr (ord 'a' + p `mod` 9)
        rowDigit  = chr (ord '1' + p `div` 9)

-- print the board: the brief form of every cell, row after row, then a line end
printb b = do
        mapM_ p1line allrows
        println ""
    where
        -- print the cells of one row without a trailing line end
        p1line ps = print (joined "" [ pfld (getc b p) | p <- ps ])

-- brief printable form of a cell
--   ? = no candidate left
--   5 = field is 5
--   0 = more than one candidate
pfld cell = case cell of
    []  -> "?"
    [x] -> show x
    _   -> "0"

-- print a result message followed by the board, and hand the board back
result msg b =
        println ("Result: " ++ msg)
            >> print ("Board: ")
            >> printb b
            >> return b

-- read the candidates of positions 0, 1 and 2 as one three digit number;
-- yields 9999999 unless these cells have exactly three candidates in total
res012 b = case [ e | p <- [0,1,2], e <- getc b p ] of
    [hundreds, tens, ones] -> hundreds*100 + tens*10 + ones
    _ -> 9999999

-- -------------------------- BOARD ALTERATION ACTIONS ---------------------------------
-- Return the board in which the candidates listed in off are removed from cell i.
-- The updated field is put in front of all other fields.
-- (No message is printed: the logging that once lived here is disabled.)
turnoff1 :: Position -> Zelle -> Brett -> IO Brett
turnoff1 i off b = return (updated : untouched)
    where
        updated   = (i, [ e | e <- getc b i, e `notElem` off ])
        untouched = filter ((i !=) • fst) b

-- Return the board in which the candidates listed in off are removed from cell i.
-- The updated field is put in front of all other fields.
-- msg names the reason; it is not used, since the logging that once lived here is disabled.
turnoff :: Int -> Zelle -> String -> Brett -> IO Brett
turnoff i off msg b = return (updated : untouched)
    where
        updated   = (i, [ e | e <- getc b i, e `notElem` off ])
        untouched = filter ((i !=) • fst) b

-- remove the candidates in off from every position in ps, one after the other
turnoffh ps off msg b = foldM (\brd\p -> turnoff p off msg brd) b ps

-- Return the board in which cell i has n as its only candidate.
-- The updated field is put in front of all other fields.
-- cname names the container; it is not used, since the logging that once lived here is disabled.
setto :: Position -> Element -> String -> Brett -> IO Brett
setto i n cname b = return ((i, [n]) : filter ((i !=) • fst) b)


-- ----------------------------- SOLVING STRATEGIES ---------------------------------------------
-- Shrink candidate sets that still contain numbers already fixed in the same row, col or box.
-- Yields one turnoff1 action per unsolved field that has something to remove.
-- NAKED SINGLEs are found in passing.
reduce b = [ turnoff1 p taken |
                (p, cell) <- b,                     -- for each field
                unsolved cell,                      --  that is not solved yet
                taken = fixedAround p cell,         --  elements to remove from its candidates
                taken != [] ]
    where
        -- values of solved cells in the containers of p that are still candidates of cell
        fixedAround p cell = [ s |
                (region, _) <- containers,
                [s] <- map (getc b) (region p),
                s `elem` cell ]

-- HIDDEN SINGLE: a number that is candidate of exactly one unsolved field of a container
-- can go in no other place of that container, so the field is set to it.
hiddenSingle b = [ setto i n cname |
            (cname, rcb) <- allrcb,                 -- for each container
            n <- elements,                          --  and each number
            -- matches only if exactly one field of the container qualifies
            [(i, _)] <- [ filter (holds n) (map (getf b) rcb) ] ]
    where
        holds n (_, cell) = unsolved cell && n `elem` cell

-- look for NAKED PAIRS, TRIPLES, QUADS
-- n is the tuple size, b the board.
-- If exactly n unsolved fields of a container have all their candidates inside an n-tuple t,
-- the elements of t are turned off in every other field of that container that has one of them.
-- Yields one turnoff action per such field.
nakedPair n b = [ turnoff p t ("(naked tuple in " ++ nm ++ ")") |           -- SELECT pos, tuple, name
            -- n <- [2,3,4],                    //  FOR n IN [2,3,4]
            (nm, rcb) <- allrcb,             --    FOR rcb IN containers
            fs = map (getf b) rcb,              --      let fs = fields for rcb positions
            u  = (fold union [] . filter unsolved . map snd) fs,   -- let u = union of non single candidates
            t <- n `outof` u,                   --      FOR t IN n-tuples
            hit = (filter ((`subset` t) . snd) . filter (unsolved . snd)) fs,   -- unsolved fields covered by t
            length hit == n,                    --      WHERE exactly n fields are covered
            (p, cell) <- fs,                    --      FOR each field of the container
            p `notElem` map fst hit,            --      WHERE it is not one of the covered fields
            any (`elem` cell) t                 --      WHERE it has something to turn off
            ]

-- look for HIDDEN PAIRS, TRIPLES or QUADS
-- n is the tuple size, b the board.
-- If the elements of an n-tuple t occur in exactly n unsolved fields of a container,
-- all candidates of those fields that are not in t are turned off in them.
-- Yields one turnoff action per such field that has candidates outside t.
hiddenPair n b = [ turnoff p off ("(hidden " ++ show t ++ " in " ++ nm ++ ")") |           -- SELECT pos, tuple, name
            -- n <- [2,3,4],                    //  FOR n IN [2,3,4]
            (nm, rcb) <- allrcb,             --    FOR rcb IN containers
            fs = map (getf b) rcb,              --      let fs = fields for rcb positions
            u  = (fold union [] . filter ((>1) . length) . map snd) fs,   -- let u = union of non single candidates
            t <- n `outof` u,                   --      FOR t IN n-tuples
            hit = (filter (any ( `elem` t) . snd) . filter (unsolved . snd)) fs,   -- unsolved fields with an element of t
            length hit == n,                    --      WHERE exactly n fields are hit
            off = (fold union [] . map snd) hit `minus` t,      -- candidates of the hit fields that are not in t
            off != [],                          --      WHERE there is something to turn off
            (p, cell) <- hit,                   --      FOR each hit field
            ! (cell `subset` t)                 --      WHERE it has candidates outside t
            ]

-- set operations on lists
xs `subset` ys = !(any (`notElem` ys) xs)           -- is every member of xs in ys?
xs `union`  ys = (uniq • sort) (xs ++ ys)           -- sorted, without repetitions
xs `minus`  ys = [ x | x <- xs, x `notElem` ys ]    -- members of xs that are not in ys
xs `common` ys = [ x | x <- xs, x `elem` ys ]       -- members of xs that are also in ys
-- all ways to choose n members of a list, keeping the order of the list
-- (for n below 2 the members come back as one-element lists)
n `outof` xs
    | length xs < n = []
    | otherwise     = pick xs
    where
        pick []         = []
        pick ys | 1 >= n = [ [y] | y <- ys ]
        -- either take the first member and n-1 of the rest, or n of the rest
        pick (y:rest)   = [ y:t | t <- (n-1) `outof` rest ] ++ (n `outof` rest)

-- is b among the values that f yields for a?
same f a b = elem b (f a)

-- which containers are checked against which other kind of container, with the reason to report
intersectionlist =
    [ (allboxs,            row, "box/row intersection")
    , (allboxs,            col, "box/col intersection")
    , (allrows ++ allcols, box, "line/box intersection")
    ]
-- look for locked candidates:
-- if all places of a candidate c within one container (rcb) lie in the same container
-- of another kind, c is turned off in the fields of that other container outside of rcb.
-- Yields one turnoff action per field where c can be removed.
intersections b = [
    turnoff pos [c] reason |    -- SELECT position, candidate, reason
        (from, container, reason) <- intersectionlist,
        rcb <- from,
        fs = (filter (unsolved . snd) . map (getf b)) rcb,        -- fs = unsolved fields in rcb
        c <- (fold union [] • map snd) fs,                          -- FOR c IN union of candidates
        cpos = (map fst • filter ((c `elem`) • snd)) fs,            -- cpos = positions where c occurs
        cpos != [],                                                 -- WHERE cpos is not empty
        all (same container (head cpos)) (tail cpos),               -- WHERE all positions are in the intersection
        -- we can remove all occurrences of c that are in container, but not in rcb
        (pos, cell) <- map (getf b) (container (head cpos)),
        c `elem` cell,
        pos `notElem` rcb ]


-- look for an XY Wing
--  - there exists a cell A with candidates X and Y
--  - there exists a cell B with candidates X and Z that shares a container with A
--  - there exists a cell C with candidates Y and Z that shares a container with A
-- reasoning
--  - if A is X, B will be Z
--  - if A is Y, C will be Z
--  - since A will indeed be X or Y -> B or C will be Z
--  - thus, no cell that can see B and C can be Z
-- Yields one turnoff action for every such cell that has Z as candidate.
xyWing board = [ turnoff p [z] ("xy wing " ++ pname b ++ " " ++ pname c ++ " because of " ++ pname a) |
        (a, [x,y]) <- board,                            -- there exists a cell a with candidates x and y
        rcba = map (getf board) (row a ++ col a ++ box a),  -- rcba = all fields that share a container with a
        (b, [b1, b2]) <- rcba,
        b != a,
        b1 == x && b2 != y || b2 == x && b1 != y,       -- there exists a cell B with candidates x and z
        z = if b1 == x then b2 else b1,                 -- z is the candidate of B that is not x
        (c, [c1, c2]) <- rcba,
        c != a, c!= b,
        c1 == y && c2 == z || c1 == z && c2 == y,       -- there exists a cell C with candidates y and z
        -- ps = positions that share a container with b and also one with c
        ps = (uniq . sort) ((row b ++ col b ++ box b) `common` (row c ++ col c ++ box c)),
        -- remove z in ps
        (p, cs) <- map (getf board) ps,
        p != b, p != c,
        z `elem` cs ]

-- look for a N-Fish (2: X-Wing, 3: Swordfish, 4: Jellyfish)
-- When all candidates for a particular digit in N rows are located
-- in only N columns, we can eliminate all candidates from those N columns
--  which are not located on those N rows
-- The local function fish shadows the top level one. It is used twice:
-- once as described above, and once with the roles of rows and columns exchanged.
fish n board = fish "row" allrows row col ++ fish "col" allcols col row where
    -- name of the pattern for the message
    fishname 2 = "X-Wing"
    fishname 3 = "Swordfish"
    fishname 4 = "Jellyfish"
    fishname _ = "unknown fish"
    -- the parameters allrows, row and col shadow the top level names
    fish nm allrows row col = [ turnoff p [x] (fishname n ++ " in " ++ nm ++ " " ++ show (map (pname . head) rset)) |
        rset <- n `outof` allrows,          -- take n rows (or cols)
        x <- elements,                      -- look for certain number
        rflds = map (filter ((>1) . length . snd) . map (getf board)) rset,       -- unsolved fields in the rowset
        colss  = (map (map (head . col . fst) . filter ((x `elem`) . snd)) rflds),   -- where x occurs in candidates
        all ((>1) . length) colss,         -- x must appear in at least 2 cols
        cols = fold union [] colss,         -- first positions of all cols where x occurs
        length cols == n,                   -- there must be exactly n of them
        cstart <- cols,
        (p, cell) <- map (getf board) (col cstart),     -- for each field in such a col
        x `elem` cell,                                  -- that has x as candidate
        all (p `notElem`) rset]                         -- and is in none of the n rows


-- Immediate consequences of an assumption of the form (p `IS` e) or (p `ISNOT` e),
-- as a sorted list without repetitions.
--  - if p is e, p is no other of its candidates, and no field that shares a
--    container with p and has e as candidate is e
--  - if p is not e and has two candidates, it is the other one; and where only one
--    other position of a container of p has e as candidate, that position is e
conseq board (IS p e) = (uniq • sort) (others ++ peers)
    where
        others = [ p `ISNOT` x | x <- getc board p, x != e ]
        peers  = [ a `ISNOT` e |
                    (a, cs) <- map (getf board) (row p ++ col p ++ box p),
                    a != p,
                    e `elem` cs ]
conseq board (ISNOT p e) = (uniq • sort) (forced ++ hidden)
    where
        forced = case getc board p of
                    cs@[_, _] -> [ p `IS` x | x <- cs, x != e ]
                    _         -> []
        hidden = [ a `IS` e |
                    cp <- [row p, box p, col p],
                    -- matches only if exactly one other position can take e
                    [a] <- [ [ q | q <- cp, p != q, e `elem` getc board q ] ] ]

-- Do two assumptions contradict each other? They do if they are about the same position and
-- either demand different elements or one demands what the other one excludes.
contradicts first second = case (first, second) of
    (IS a x,    IS b y)    -> a == b && x != y
    (IS a x,    ISNOT b y) -> a == b && x == y
    (ISNOT a x, IS b y)    -> a == b && x == y
    (ISNOT _ _, ISNOT _ _) -> false

-- the position an assumption is about
aPos assumption = case assumption of
    IS p _    -> p
    ISNOT p _ -> p

-- Elements to turn off at the position of an assumption, given whether the assumption holds.
-- Either the element itself goes, or all other candidates of the position.
toClear board holds (IS p x)
    | holds     = filter (x!=) (getc board p)
    | otherwise = [x]
toClear board holds (ISNOT p x)
    | holds     = [x]
    | otherwise = filter (x!=) (getc board p)


-- Look for assumptions that are contradicted by one of their own implications.
-- For each of them, the first contradicting chain is taken and the assumption is applied as false.
chain board paths = [ refute a (head cs) (reverse cs) |
        (a, css) <- paths,
        cs <- take 1 (filter (contradicts a • head) css)
        ]
    where
        refute a c cs = turnoff (aPos a) (toClear board false a) reason
            where
                reason = "Assumption " ++ show a ++ " implies " ++ show c ++ "\n\t"
                    ++ showcs cs ++ "\n\t"
                    ++ "Therefore, " ++ show a ++ " must be false."

-- look for an assumption that leads to contradictory implications
-- this assumption must be false
-- paths is a list of assumptions, each with its conclusion chains; the head of a chain
-- is its final conclusion. At most one action is produced per assumption.
chainContra board paths = [ solution a (reverse pro) (reverse contra) |
        (a, css) <- paths,          -- FOR ALL assumptions "a" with list of conclusions "css"
        (pro, contra) <- take 1 [ (pro, contra) |
            pro <- (uniqBy (using head) . sortBy (comparing head)) css,                 -- FOR ALL conclusion chains "pro"
            c = head pro,               -- LET "c" BE the final conclusion
            contra <- take 1 (filter ((contradicts c) . head) css)   -- THE FIRST conclusion that contradicts c
        ]
      ]
    where
        -- apply "a" as false; both chains are reported in the reason
        solution a pro con = turnoff (aPos a) (toClear board false a) reason where
            reason = ("assumption " ++ show a ++ " leads to contradictory conclusions\n\t"
                        ++ showcs pro ++ "\n\t" ++ showcs con)



-- look for a common implication c of some assumptions ai, where at least 1 ai is true
-- so that (a0 OR a1 OR a2 OR ...) IMPLIES c
-- For all cells pi in same container that have x as candidate, we can construct (p0==x OR p1==x OR ... OR pi==x)
-- For a cell p with candidates ci, we can construct (p==c0 OR p==c1)
-- paths is a list of assumptions, each with its conclusion chains; the head of a chain
-- is its final conclusion.
cellRegionChain board paths = [ solution b as (map head os) |
        as <- cellas ++ regionas,           -- one of as must be true
        iss = filter ((`elem` as) . fst) paths,    -- the implications for as
        (a, ass) <- take 1 iss,             -- implications for first assumption
        fs <- (uniqBy (using head) . sortBy (comparing head)) ass,
        b = head fs,                        -- final conclusions of first assumption
        os = [fs] : map (take 1 . filter ((b==) . head) . snd) (tail iss), -- look for implications with same conclusion
        all ([]!=) os]                      -- every assumption must have a chain that ends in b
    where
        -- for each cell with at least 2 candidates: it is one of its candidates
        cellas   = [ map (p `IS`) candidates | (p, candidates@(_:_:_)) <- board ]
        -- for each container and element that is candidate in more than 1 field: one of the fields is it
        regionas = [ map (`IS` e) ps |
            region <- map (map (getf board)) (allrows ++ allcols ++ allboxs),
            e <- elements,
            ps = map fst (filter ((e `elem`) . snd) region),
            length ps > 1 ]
        -- apply the common conclusion "b" as true; all chains are reported in the reason
        solution b as oss = turnoff (aPos b) (toClear board true b) reason where
            reason = "all of the assumptions " ++ joined ", " (map show as) ++ " imply " ++ show b ++ "\n\t"
                ++ joined "\n\t" (map (showcs . reverse) oss) ++ "\n\t"
                ++ "One of them must be true, so " ++ show b ++ " must be true."


{-
    Some functions need a data structure like
        [ (Assumption, [[Assumption]]) ]
    i.e. a list of possible assumptions, each together with all of its implication chains.
    Ideally, an implication chain is kept in reverse order,
    because then it is easy to find:
    - assumptions that lead to a self-contradiction.
    - everything that follows from a certain assumption (map (map head) [[a]])
    -...
-}
--- List of all assumptions for a given board: for every field that is not solved,
--- first the ISNOT and then the IS assumption for each of its candidates.
assumptions :: Brett -> [Assumption]
assumptions board = concat [ map (ISNOT p) cs ++ map (IS p) cs |
                (p, cs) <- board,
                !(single cs) ]

--- the immediate consequences of each assumption, in the order of the assumptions
consequences :: Brett -> [Assumption] -> [[Assumption]]
consequences board assumed = [ conseq board a | a <- assumed ]

--- tree that maps every assumption for the board to its immediate consequences
acstree :: Brett -> Tree Assumption [Assumption]
acstree board = Tree.fromList (zip assumed (consequences board assumed))
    where
        assumed = assumptions board

-- tree lookup without the Maybe: it is an error if the assumption is not in the tree
find :: Tree Assumption [Assumption] -> Assumption -> [Assumption]
find t a = case t.lookup a of
    Just cs -> cs
    Nothing -> error ("no consequences for " ++ show a)

-- Pair every assumption in the tree with its implication chains.
-- A chain is kept in reverse order: its head is the final conclusion, its last member the assumption.
-- For performance reasons, we confine ourselves to the first 1000 implication chains per assumption.
mkPaths :: Tree Assumption [Assumption] -> [ (Assumption, [[Assumption]]) ]
mkPaths acst = map impl  (keys acst)   -- {[a1], [a2], [a3] ]
    where
        -- [Assumption] -> [(a, [chains, ordered by length]
        impl a = (a, impls [[a]])
        -- chains of growing length, starting with the assumption alone;
        -- NOTE(review): presumably takeUntil null ends this once no chain can be extended - confirm
        impls ns = (take 1000 • concat • takeUntil null • iterate expandchain) ns
        -- expandchain :: [[Assumption]] -> [[Assumption]]
        -- extend every chain by each consequence of its head
        expandchain css = [ (n:a:as) |
            (a : as) <- css,               -- list of assumptions
            n <- find acst a,              -- consequences of a
            n `notElem` as                 -- avoid loops
          ]
        -- uni (a:as) = a : uni (filter ((head a !=) • head) as)
        -- uni [] = empty
        -- empty = []


-- ------------------ SOLVE A SUDOKU --------------------------
-- Apply all available strategies until nothing changes anymore
-- Strategy functions are supposed to return a list of
-- functions, which, when applied to a board, give a changed board.
-- When a strategy does not find anything to alter,
-- it returns [], and the next strategy can be tried.
-- The guards are tried in order, so the first strategy that finds something is applied,
-- and the resulting board is solved again from the top.
-- The final board is printed by result and returned.
solve b
    | all (single . snd) b       = result "Solved" b
    | any (([]==) . snd) b       = result "not solvable" b     -- some field has no candidate left
    | res@(_:_) <- reduce b       = apply b res >>=solve       -- compute smallest candidate sets
    -- comment "candidate sets are up to date" = ()
    | res@(_:_) <- hiddenSingle b  = apply b res >>= solve     -- find HIDDEN SINGLES
    -- comment "no more hidden singles" = ()
    | res@(_:_) <- intersections b = apply b res >>= solve     -- find locked candidates
    -- comment "no more intersections" = ()
    | res@(_:_) <- nakedPair 2 b     = apply b res >>= solve     -- find NAKED PAIRS
    -- comment "no more naked pairs" = ()
    | res@(_:_) <- hiddenPair  2 b   = apply b res >>= solve      -- find HIDDEN PAIRS
    -- comment "no more hidden pairs" = ()
    -- res@(_:_) <- nakedPair 3 b     = apply b res >>= solve       // find NAKED PAIRS, TRIPLES or QUADRUPELS
    -- | comment "no more naked triples" = ()
    -- res@(_:_) <- hiddenPair  3 b    = apply b res >>= solve      // find HIDDEN PAIRS, TRIPLES or QUADRUPELS
    -- | comment "no more hidden triples" = ()
    -- res@(_:_) <- nakedPair 4 b     = apply b res >>=solve       // find NAKED PAIRS, TRIPLES or QUADRUPELS
    -- | comment "no more naked quadruples" = ()
    -- res@(_:_) <- hiddenPair  4 b    = apply b res >>=solve      // find HIDDEN PAIRS, TRIPLES or QUADRUPELS
    -- | comment "no more hidden quadruples" = ()
    | res@(_:_) <- xyWing b            = apply b res >>=solve      -- find XY WINGS
    -- comment "no more xy wings"       = ()
    | res@(_:_) <- fish 2 b            = apply b res >>=solve      -- find 2-FISH
    -- comment "no more x-wings"        = ()
    -- res@(_:_) <- fish 3 b            = apply b res >>=solve      // find 3-FISH
    -- | comment "no more swordfish"      = ()
    -- res@(_:_) <- fish 4 b            = apply b res >>=solve      // find 4-FISH
    -- | comment "no more jellyfish"      = ()
    -- | comment pcomment                 = ()
    -- of the chain strategies, only the first 9 actions found are applied
    | res@(_:_) <- chain b paths             = apply b (take 9 res) >>= solve  -- find forcing chains
    | res@(_:_) <- cellRegionChain b paths   = apply b (take 9 res) >>= solve  -- find common conclusion for true assumption
    | res@(_:_) <- chainContra b paths       = apply b (take 9 res) >>= solve  -- find assumptions that allow to infer both a and !a
    -- comment "consistent conclusions only"       = ()

    | otherwise = result "ambiguous" b      -- no strategy found anything to alter
    where
        -- apply the actions one after the other, each to the board the previous one returned
        apply brd fs = foldM (\b\f -> f b) brd fs
        -- assumptions with their implication chains for the current board
        paths = mkPaths (acstree b)
        -- pcomment = show (length paths) ++ " assumptions with " ++ show (fold (+) 0 (map (length <~ snd) paths))
        --    ++ " implication chains"

-- comment com = do stderr << com << "\n" for false
-- log com     = do stderr << com << "\n" for true

--- Turn a string into a row of 9 cells.
--- A digit 1..9 gives a solved cell, a blank is ignored, any other character gives
--- a cell with all candidates. Short strings are padded so that there are at least 9 cells.
mkrow :: String -> [Zelle]
mkrow s = take 9 [ z | c <- unpacked padded, z = cell c, z != [] ]
    where
        padded = s ++ "---------"
        cell c
            | c == ' '              = []    -- ignored
            | c >= '1' && c <= '9'  = [ord c - ord '0']
            | otherwise             = elements

-- Entry point.
--  - without arguments, or with -h or -help: print the usage on stderr
--  - with a single 81 character position: solve that puzzle
--  - otherwise: take the arguments as names of files with puzzles, solve them all,
--    and print the sum of the res012 values of the final boards for each file
main ["-h"]    = main []
main ["-help"] = main []
main [] = do
        mapM_ stderr.println usage
        return ()
    where
        usage = [
            "usage: java Sudoku file ...",
            "       java Sudoku position",
            "where position is a 81 char string consisting of digits",
            "One can get such a string by going to",
            "http://www.sudokuoftheday.com/pages/s-o-t-d.php",
            "Right click on the puzzle and open it in new tab",
            "Copy the 81 digits from the URL in the address field of your browser.",
            "",
            "There is also a file with hard sudokus in examples/top95.txt\n"]


main [s@#^[0-9\W]{81}$#] = solve (zip positions (decode s)) >> return ()

main files = mapM_ sudoku files
    where
        -- solve all puzzles of one file
        sudoku file = do
            br <- openReader file
            content <- BufferedReader.getLines br
            bs <- process content
            ss <- mapM solveShown bs
            println ("Euler: " ++ show (sum (map res012 ss)))
            return ()
        -- print the puzzle, then solve it
        solveShown b = print "Puzzle: " >> printb b >> solve b

-- Turn each character into a cell: a digit 1..9 gives a solved cell,
-- anything else a cell with all candidates.
-- "--3-" => [1..9, 1..9, [3], 1..9]
decode s = [ candi c | c <- unpacked s ] where
        candi c = if c >= '1' && c <= '9' then [(ord c - ord '0')] else elements
-- Turn the lines of a file into boards. Two formats are recognized:
--  - a line of 81 characters is a complete puzzle
--  - a line of 9 characters that is followed by 8 more lines of 9 characters
--    is the first row of a puzzle that is given as 9 rows
-- Any other line is reported on stderr and skipped.
-- The rows of a 9 line puzzle are consumed as a whole. Processing continues with the
-- line after the last row, so that rows 2 to 9 are not looked at again as if a new
-- puzzle could start there.
process [] = return []
process (s:ss)
    | length s == 81 = consider b1 ss
    | length s == 9,
      length acht == 8,
      all ((9==) • length) acht = consider b2 (drop 8 ss)
    | otherwise = do
            stderr.println ("skipped line: " ++ s)
            process ss
    where
        acht = take 8 ss                        -- the (up to) 8 lines that follow s
        neun = fold (++) "" (s:acht)            -- all 9 rows as one string
        b1 = zip positions (decode s)           -- board from an 81 character line
        b2 = zip positions (decode neun)        -- board from 9 lines of 9 characters
        -- put the board in front of the boards found in the remaining lines
        consider b rest = do
            -- print "Puzzle: "
            -- printb b
            bs <- process rest
            return (b:bs)

